{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import ray\n",
    "from ray import tune\n",
    "\n",
    "import torch # to remove later\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import models\n",
    "import networks\n",
    "from datasets import PreprocessedSpeechDataLoader, VaryingDataLoader\n",
    "from nupic.research.frameworks.pytorch.image_transforms import RandomNoise\n",
    "\n",
    "from torchsummary import summary\n",
    "\n",
    "import math\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torchvision import models\n",
    "from nupic.torch.modules import Flatten, KWinners, KWinners2d\n",
    "from networks_module.layers import DSConv2d, RandDSConv2d, SparseConv2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = dict(\n",
    "    device=(\"cuda\" if torch.cuda.device_count() > 0 else \"cpu\"),\n",
    "    dataset_name=\"PreprocessedGSC\",\n",
    "    data_dir=\"~/nta/datasets/gsc\",\n",
    "    batch_size_train=(4, 16),\n",
    "    batch_size_test=1000,\n",
    "\n",
    "    # ----- Network Related ------\n",
    "    # SE\n",
    "    # model=tune.grid_search([\"BaseModel\", \"SparseModel\", \"DSNNMixedHeb\", \"DSNNConvHeb\"]),\n",
    "    model=\"DSNNConvHeb\",\n",
    "    network=\"gsc_conv_heb\",\n",
    "\n",
    "    # ----- Optimizer Related ----\n",
    "    optim_alg=\"SGD\",\n",
    "    momentum=0,\n",
    "    learning_rate=0.01,\n",
    "    weight_decay=0.01,\n",
    "    lr_scheduler=\"StepLR\",\n",
    "    lr_gamma=0.90,\n",
    "    use_kwinners = True,\n",
    "    # use_kwinners=tune.grid_search([True, False]),\n",
    "\n",
    "    # ----- Dynamic-Sparse Related  - FC LAYER -----\n",
    "    epsilon=184.61538/3, # 0.1 in the 1600-1000 linear layer\n",
    "    sparse_linear_only = True,\n",
    "    start_sparse=1,\n",
    "    end_sparse=-1, # don't get last layer\n",
    "    weight_prune_perc=0.15,\n",
    "    hebbian_prune_perc=0.60,\n",
    "    pruning_es=True,\n",
    "    pruning_es_patience=0,\n",
    "    pruning_es_window_size=5,\n",
    "    pruning_es_threshold=0.02,\n",
    "    pruning_interval=1,\n",
    "\n",
    "    # ----- Dynamic-Sparse Related  - CONV -----\n",
    "    prune_methods='dynamic',\n",
    "    hebbian_prune_frac=0.99,\n",
    "    magnitude_prune_frac=0.0,\n",
    "    sparsity=0.98,\n",
    "    update_nsteps=50,\n",
    "    prune_dims=tuple(),\n",
    "\n",
    "    # ----- Additional Validation -----\n",
    "    test_noise=False,\n",
    "    noise_level=0.1,\n",
    "\n",
    "    # ----- Debugging -----\n",
    "    debug_weights=True,\n",
    "    debug_sparse=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "network = networks.gsc_conv_heb(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "          DSConv2d-1           [-1, 64, 28, 28]           1,664\n",
      "         _NullConv-2           [-1, 25, 28, 28]             625\n",
      "       BatchNorm2d-3           [-1, 64, 28, 28]               0\n",
      "         MaxPool2d-4           [-1, 64, 14, 14]               0\n",
      "        KWinners2d-5           [-1, 64, 14, 14]               0\n",
      "          DSConv2d-6           [-1, 64, 10, 10]         102,464\n",
      "         _NullConv-7         [-1, 1600, 10, 10]       2,560,000\n",
      "       BatchNorm2d-8           [-1, 64, 10, 10]               0\n",
      "         MaxPool2d-9             [-1, 64, 5, 5]               0\n",
      "       KWinners2d-10             [-1, 64, 5, 5]               0\n",
      "          Flatten-11                 [-1, 1600]               0\n",
      "           Linear-12                 [-1, 1000]       1,601,000\n",
      "      BatchNorm1d-13                 [-1, 1000]               0\n",
      "         KWinners-14                 [-1, 1000]               0\n",
      "           Linear-15                   [-1, 12]          12,012\n",
      "================================================================\n",
      "Total params: 4,277,765\n",
      "Trainable params: 1,717,140\n",
      "Non-trainable params: 2,560,625\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 2.48\n",
      "Params size (MB): 16.32\n",
      "Estimated Total Size (MB): 18.81\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "summary(network, input_size=(1, 32, 32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "network(torch.rand(10,1,32,32));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): DSConv2d(\n",
       "    1, 64, kernel_size=(5, 5), stride=(1, 1)\n",
       "    (grouped_conv): _NullConv(25, 25, kernel_size=(5, 5), stride=(1, 1), groups=25, bias=False)\n",
       "  )\n",
       "  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (3): KWinners2d(channels=64, n=12544, percent_on=0.095, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "  (4): DSConv2d(\n",
       "    64, 64, kernel_size=(5, 5), stride=(1, 1)\n",
       "    (grouped_conv): _NullConv(102400, 1600, kernel_size=(5, 5), stride=(1, 1), groups=1600, bias=False)\n",
       "  )\n",
       "  (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (7): KWinners2d(channels=64, n=1600, percent_on=0.125, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "network.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Flatten()\n",
       "  (1): Linear(in_features=1600, out_features=1000, bias=True)\n",
       "  (2): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "  (3): KWinners(n=1000, percent_on=0.1, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "  (4): Linear(in_features=1000, out_features=12, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "network.classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<bound method Module.named_modules of GSCHeb(\n",
       "  (features): Sequential(\n",
       "    (0): DSConv2d(\n",
       "      1, 64, kernel_size=(5, 5), stride=(1, 1)\n",
       "      (grouped_conv): _NullConv(25, 25, kernel_size=(5, 5), stride=(1, 1), groups=25, bias=False)\n",
       "    )\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (3): KWinners2d(channels=64, n=12544, percent_on=0.095, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "    (4): DSConv2d(\n",
       "      64, 64, kernel_size=(5, 5), stride=(1, 1)\n",
       "      (grouped_conv): _NullConv(102400, 1600, kernel_size=(5, 5), stride=(1, 1), groups=1600, bias=False)\n",
       "    )\n",
       "    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (7): KWinners2d(channels=64, n=1600, percent_on=0.125, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "  )\n",
       "  (classifier): Sequential(\n",
       "    (0): Flatten()\n",
       "    (1): Linear(in_features=1600, out_features=1000, bias=True)\n",
       "    (2): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n",
       "    (3): KWinners(n=1000, percent_on=0.1, boost_strength=1.5, boost_strength_factor=0.9, k_inference_factor=1.5, duty_cycle_period=1000)\n",
       "    (4): Linear(in_features=1000, out_features=12, bias=True)\n",
       "  )\n",
       ")>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "network.named_modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "features.0\n",
      "<class 'networks_module.layers.DSConv2d'>\n",
      "features.4\n",
      "<class 'networks_module.layers.DSConv2d'>\n"
     ]
    }
   ],
   "source": [
    "for name, module in network.named_modules():\n",
    "    # if it is a dsconv layer\n",
    "    if isinstance(module, DSConv2d):\n",
    "        print(name)\n",
    "        print(module.__class__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
